{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "071c7845",
   "metadata": {},
   "source": [
    "下面介绍基于循环神经网络的编码器和解码器的代码实现。首先是作为编码器的循环神经网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f682c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class RNNEncoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNEncoder, self).__init__()\n",
    "        # 隐层大小\n",
    "        self.hidden_size = hidden_size\n",
    "        # 词表大小\n",
    "        self.vocab_size = vocab_size\n",
    "        # 词嵌入层\n",
    "        self.embedding = nn.Embedding(self.vocab_size,\\\n",
    "            self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # inputs: batch * seq_len\n",
    "        # 注意门控循环单元使用batch_first=True，因此输入需要至少batch为1\n",
    "        features = self.embedding(inputs)\n",
    "        output, hidden = self.gru(features)\n",
    "        return output, hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dc8af74",
   "metadata": {},
   "source": [
    "接下来是作为解码器的另一个循环神经网络的代码实现。<!--我们使用编码器最终的输出用作解码器的初始隐状态，这个输出向量有时称作上下文向量（context vector），它编码了整个源序列的信息。解码器最初的输入词元是“\\<sos\\>”（start-of-string）。解码器的目标为，输入编码器隐状态，一步一步解码出整个目标序列。解码器每一步的输入可以是真实目标序列，也可以是解码器上一步的预测结果。-->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9beee561",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        # 序列到序列任务并不限制编码器和解码器输入同一种语言，\n",
    "        # 因此解码器也需要定义一个嵌入层\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # 用于将输出的隐状态映射为词表上的分布\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # 解码整个序列\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # 从<sos>开始解码\n",
    "        decoder_input = torch.empty(batch_size, 1,\\\n",
    "            dtype=torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        # 如果目标序列确定，最大解码步数确定；\n",
    "        # 如果目标序列不确定，解码到最大长度\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # 进行seq_length次解码\n",
    "        for i in range(seq_length):\n",
    "            # 每次输入一个词和一个隐状态\n",
    "            decoder_output, decoder_hidden = self.forward_step(\\\n",
    "                decoder_input, decoder_hidden)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # 为了与AttnRNNDecoder接口保持统一，最后输出None\n",
    "        return decoder_outputs, decoder_hidden, None\n",
    "\n",
    "    # 解码一步\n",
    "    def forward_step(self, input, hidden):\n",
    "        output = self.embedding(input)\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38306ed9",
   "metadata": {},
   "source": [
    "下面介绍基于注意力机制的循环神经网络解码器的代码实现。\n",
    "我们使用一个注意力层来计算注意力权重，其输入为解码器的输入和隐状态。\n",
    "这里使用Bahdanau注意力（Bahdanau attention），这是序列到序列模型中应用最广泛的注意力机制，特别是机器翻译任务。该注意力机制使用一个对齐模型（alignment model）来计算编码器和解码器隐状态之间的注意力分数，具体来讲就是一个前馈神经网络。相比于点乘注意力，Bahdanau注意力利用了非线性变换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d675aa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class BahdanauAttention(nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Va = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, query, keys):\n",
    "        # query: batch * 1 * hidden_size\n",
    "        # keys: batch * seq_length * hidden_size\n",
    "        # 这一步用到了广播（broadcast）机制\n",
    "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
    "        scores = scores.squeeze(2).unsqueeze(1)\n",
    "\n",
    "        weights = F.softmax(scores, dim=-1)\n",
    "        context = torch.bmm(weights, keys)\n",
    "        return context, weights\n",
    "\n",
    "class AttnRNNDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(AttnRNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.attention = BahdanauAttention(hidden_size)\n",
    "        # 输入来自解码器输入和上下文向量，因此输入大小为2 * hidden_size\n",
    "        self.gru = nn.GRU(2 * self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # 用于将注意力的结果映射为词表上的分布\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # 解码整个序列\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # 从<sos>开始解码\n",
    "        decoder_input = torch.empty(batch_size, 1, dtype=\\\n",
    "            torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        attentions = []\n",
    "\n",
    "        # 如果目标序列确定，最大解码步数确定；\n",
    "        # 如果目标序列不确定，解码到最大长度\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # 进行seq_length次解码\n",
    "        for i in range(seq_length):\n",
    "            # 每次输入一个词和一个隐状态\n",
    "            decoder_output, decoder_hidden, attn_weights = \\\n",
    "                self.forward_step(\n",
    "                    decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            attentions.append(attn_weights)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        attentions = torch.cat(attentions, dim=1)\n",
    "        # 与RNNDecoder接口保持统一，最后输出注意力权重\n",
    "        return decoder_outputs, decoder_hidden, attentions\n",
    "\n",
    "    # 解码一步\n",
    "    def forward_step(self, input, hidden, encoder_outputs):\n",
    "        embeded =  self.embedding(input)\n",
    "        # 输出的隐状态为1 * batch * hidden_size，\n",
    "        # 注意力的输入需要batch * 1 * hidden_size\n",
    "        query = hidden.permute(1, 0, 2)\n",
    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
    "        input_gru = torch.cat((embeded, context), dim=2)\n",
    "        # 输入的隐状态需要1 * batch * hidden_size\n",
    "        output, hidden = self.gru(input_gru, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden, attn_weights\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d63b031",
   "metadata": {},
   "source": [
    "\n",
    "接下来我们实现基于Transformer的编码器和解码器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9a91cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "from transformer import *\n",
    "\n",
    "class TransformerEncoder(nn.Module):\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # 直接使用TransformerLayer作为编码层，简单起见只使用一层\n",
    "        self.layer = TransformerLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # 与TransformerLM不同，编码器不需要线性层用于输出\n",
    "        \n",
    "    def forward(self, input_ids):\n",
    "        # 这里实现的forward()函数一次只能处理一句话，\n",
    "        # 如果想要支持批次运算，需要根据输入序列的长度返回隐状态\n",
    "        assert input_ids.ndim == 2 and input_ids.size(0) == 1\n",
    "        seq_len = input_ids.size(1)\n",
    "        assert seq_len <= self.embedding_layer.max_len\n",
    "        \n",
    "        # 1 * seq_len\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        attention_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        input_states = self.embedding_layer(input_ids, pos_ids)\n",
    "        hidden_states = self.layer(input_states, attention_mask)\n",
    "        return hidden_states, attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5e8cfc9c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "class MultiHeadCrossAttention(MultiHeadSelfAttention):\n",
    "    def forward(self, tgt, tgt_mask, src, src_mask):\n",
    "        \"\"\"\n",
    "        tgt: query, batch_size * tgt_seq_len * hidden_size\n",
    "        tgt_mask: batch_size * tgt_seq_len\n",
    "        src: keys/values, batch_size * src_seq_len * hidden_size\n",
    "        src_mask: batch_size * src_seq_len\n",
    "        \"\"\"\n",
    "        # (batch_size * num_heads) * seq_len * (hidden_size / num_heads)\n",
    "        queries = self.transpose_qkv(self.W_q(tgt))\n",
    "        keys = self.transpose_qkv(self.W_k(src))\n",
    "        values = self.transpose_qkv(self.W_v(src))\n",
    "        # 这一步与自注意力不同，计算交叉掩码\n",
    "        # batch_size * tgt_seq_len * src_seq_len\n",
    "        attention_mask = tgt_mask.unsqueeze(2) * src_mask.unsqueeze(1)\n",
    "        # 重复张量的元素，用以支持多个注意力头的运算\n",
    "        # (batch_size * num_heads) * tgt_seq_len * src_seq_len\n",
    "        attention_mask = torch.repeat_interleave(attention_mask,\\\n",
    "            repeats=self.num_heads, dim=0)\n",
    "        # (batch_size * num_heads) * tgt_seq_len * \\\n",
    "        # (hidden_size / num_heads)\n",
    "        output = self.attention(queries, keys, values, attention_mask)\n",
    "        # batch * tgt_seq_len * hidden_size\n",
    "        output_concat = self.transpose_output(output)\n",
    "        return self.W_o(output_concat)\n",
    "\n",
    "# TransformerDecoderLayer比TransformerLayer多了交叉多头注意力\n",
    "class TransformerDecoderLayer(nn.Module):\n",
    "    def __init__(self, hidden_size, num_heads, dropout,\\\n",
    "                 intermediate_size):\n",
    "        super().__init__()\n",
    "        self.self_attention = MultiHeadSelfAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm1 = AddNorm(hidden_size, dropout)\n",
    "        self.enc_attention = MultiHeadCrossAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm2 = AddNorm(hidden_size, dropout)\n",
    "        self.fnn = PositionWiseFNN(hidden_size, intermediate_size)\n",
    "        self.add_norm3 = AddNorm(hidden_size, dropout)\n",
    "\n",
    "    def forward(self, src_states, src_mask, tgt_states, tgt_mask):\n",
    "        # 掩码多头自注意力\n",
    "        tgt = self.add_norm1(tgt_states, self.self_attention(\\\n",
    "            tgt_states, tgt_states, tgt_states, tgt_mask))\n",
    "        # 交叉多头自注意力\n",
    "        tgt = self.add_norm2(tgt, self.enc_attention(tgt,\\\n",
    "            tgt_mask, src_states, src_mask))\n",
    "        # 前馈神经网络\n",
    "        return self.add_norm3(tgt, self.fnn(tgt))\n",
    "\n",
    "class TransformerDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "                 dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # 简单起见只使用一层\n",
    "        self.layer = TransformerDecoderLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # 解码器与TransformerLM一样，需要输出层\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "        \n",
    "    def forward(self, src_states, src_mask, tgt_tensor=None):\n",
    "        # 确保一次只输入一句话，形状为1 * seq_len * hidden_size\n",
    "        assert src_states.ndim == 3 and src_states.size(0) == 1\n",
    "        \n",
    "        if tgt_tensor is not None:\n",
    "            # 确保一次只输入一句话，形状为1 * seq_len\n",
    "            assert tgt_tensor.ndim == 2 and tgt_tensor.size(0) == 1\n",
    "            seq_len = tgt_tensor.size(1)\n",
    "            assert seq_len <= self.embedding_layer.max_len\n",
    "        else:\n",
    "            seq_len = self.embedding_layer.max_len\n",
    "        \n",
    "        decoder_input = torch.empty(1, 1, dtype=torch.long).\\\n",
    "            fill_(SOS_token)\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        for i in range(seq_len):\n",
    "            decoder_output = self.forward_step(decoder_input,\\\n",
    "                src_mask, src_states)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            \n",
    "            if tgt_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    tgt_tensor[:, i:i+1]), 1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    topi.squeeze(-1).detach()), 1)\n",
    "                \n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # 与RNNDecoder接口保持统一\n",
    "        return decoder_outputs, None, None\n",
    "        \n",
    "    # 解码一步，与RNNDecoder接口略有不同，RNNDecoder一次输入\n",
    "    # 一个隐状态和一个词，输出一个分布、一个隐状态\n",
    "    # TransformerDecoder不需要输入隐状态，\n",
    "    # 输入整个目标端历史输入序列，输出一个分布，不输出隐状态\n",
    "    def forward_step(self, tgt_inputs, src_mask, src_states):\n",
    "        seq_len = tgt_inputs.size(1)\n",
    "        # 1 * seq_len\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        tgt_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        tgt_states = self.embedding_layer(tgt_inputs, pos_ids)\n",
    "        hidden_states = self.layer(src_states, src_mask, tgt_states,\\\n",
    "            tgt_mask)\n",
    "        output = self.output_layer(hidden_states[:, -1:, :])\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e58b811",
   "metadata": {},
   "source": [
    "下面以机器翻译（中-英）为例展示如何训练序列到序列模型。这里使用的是中英文Books数据，其中中文标题来源于第4章所使用的数据集，英文标题是使用已训练好的机器翻译模型从中文标题翻译而得，因此该数据并不保证准确性，仅用于演示。\n",
    "\n",
    "首先需要对源语言和目标语言分别建立索引，并记录词频。\n",
    "\n",
    "<!-- 代码来自：\n",
    "\n",
    "https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html <span style=\"color:blue;font-size:20px\">BSD-3-Clause license</span>\n",
    "\n",
    "下载链接：\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.en\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.zh\n",
    " -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6b258fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "SOS_token = 0\n",
    "EOS_token = 1\n",
    "\n",
    "class Lang:\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.word2index = {}\n",
    "        self.word2count = {}\n",
    "        self.index2word = {0: \"<sos>\", 1: \"<eos>\"}\n",
    "        self.n_words = 2  # Count SOS and EOS\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        for word in sentence.split(' '):\n",
    "            self.addWord(word)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        if word not in self.word2index:\n",
    "            self.word2index[word] = self.n_words\n",
    "            self.word2count[word] = 1\n",
    "            self.index2word[self.n_words] = word\n",
    "            self.n_words += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1\n",
    "            \n",
    "    def sent2ids(self, sent):\n",
    "        return [self.word2index[word] for word in sent.split(' ')]\n",
    "    \n",
    "    def ids2sent(self, ids):\n",
    "        return ' '.join([self.index2word[idx] for idx in ids])\n",
    "\n",
    "import unicodedata\n",
    "import string\n",
    "import re\n",
    "import random\n",
    "\n",
    "# 文件使用unicode编码，我们将unicode转为ASCII，转为小写，并修改标点\n",
    "def unicodeToAscii(s):\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "def normalizeString(s):\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    # 在标点前插入空格\n",
    "    s = re.sub(r\"([,.!?])\", r\" \\1\", s)\n",
    "    return s.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "01ae1e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取文件，一共有两个文件，两个文件的同一行对应一对源语言和目标语言句子\n",
    "def readLangs(lang1, lang2):\n",
    "    # 读取文件，分句\n",
    "    lines1 = open(f'{lang1}.txt', encoding='utf-8').read()\\\n",
    "        .strip().split('\\n')\n",
    "    lines2 = open(f'{lang2}.txt', encoding='utf-8').read()\\\n",
    "        .strip().split('\\n')\n",
    "    print(len(lines1), len(lines2))\n",
    "    \n",
    "    # 规范化\n",
    "    lines1 = [normalizeString(s) for s in lines1]\n",
    "    lines2 = [normalizeString(s) for s in lines2]\n",
    "    if lang1 == 'zh':\n",
    "        lines1 = [' '.join(list(s.replace(' ', ''))) for s in lines1]\n",
    "    if lang2 == 'zh':\n",
    "        lines2 = [' '.join(list(s.replace(' ', ''))) for s in lines2]\n",
    "    pairs = [[l1, l2] for l1, l2 in zip(lines1, lines2)]\n",
    "\n",
    "    input_lang = Lang(lang1)\n",
    "    output_lang = Lang(lang2)\n",
    "    return input_lang, output_lang, pairs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6a561e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n",
      "['孙 正 义 的 时 间 管 理 术', 'sun just time management']\n"
     ]
    }
   ],
   "source": [
    "# 为了快速训练，过滤掉一些过长的句子\n",
    "MAX_LENGTH = 30\n",
    "\n",
    "def filterPair(p):\n",
    "    return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
    "        len(p[1].split(' ')) < MAX_LENGTH\n",
    "\n",
    "def filterPairs(pairs):\n",
    "    return [pair for pair in pairs if filterPair(pair)]\n",
    "\n",
    "def prepareData(lang1, lang2):\n",
    "    input_lang, output_lang, pairs = readLangs(lang1, lang2)\n",
    "    print(f\"读取 {len(pairs)} 对序列\")\n",
    "    pairs = filterPairs(pairs)\n",
    "    print(f\"过滤后剩余 {len(pairs)} 对序列\")\n",
    "    print(\"统计词数\")\n",
    "    for pair in pairs:\n",
    "        input_lang.addSentence(pair[0])\n",
    "        output_lang.addSentence(pair[1])\n",
    "    print(input_lang.name, input_lang.n_words)\n",
    "    print(output_lang.name, output_lang.n_words)\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfa6874b",
   "metadata": {},
   "source": [
    "为了便于训练，对每一对源-目标句子需要准备一个源张量（源句子的词元索引）和一个目标张量（目标句子的词元索引）。在两个句子的末尾会添加“\\<eos\\>”。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "92baa7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n"
     ]
    }
   ],
   "source": [
    "def get_train_data():\n",
    "    input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "    train_data = []\n",
    "    for idx, (src_sent, tgt_sent) in enumerate(pairs):\n",
    "        src_ids = input_lang.sent2ids(src_sent)\n",
    "        tgt_ids = output_lang.sent2ids(tgt_sent)\n",
    "        # 添加<eos>\n",
    "        src_ids.append(EOS_token)\n",
    "        tgt_ids.append(EOS_token)\n",
    "        train_data.append([src_ids, tgt_ids])\n",
    "    return input_lang, output_lang, train_data\n",
    "        \n",
    "input_lang, output_lang, train_data = get_train_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "411c9ed3",
   "metadata": {},
   "source": [
    "<!--训练时，我们使用编码器编码源句子，并且保留每一步的输出向量和最终的隐状态。然后解码器以“\\<sos\\>”作为第一个输入，以及编码器的最终隐状态作为它的第一个隐状态。-->\n",
    "接下来是训练代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8b85d1ba",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch-19, loss=0.1184: 100%|█| 20/20 [38:39<00:00, 115.98s/i\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGwCAYAAACHJU4LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABM8klEQVR4nO3deVxU5f4H8M+wzLBvsongiooi7kuoueSumVm3zKzUtl9lmZmWVpZmha3XUq91y9TuTb0tagu57/uCoLiEoiioLC7AsA4w8/z+QI4MzMAAw5wZ+Lxfr3m9Zs555pnvaSA+Puc5z1EIIQSIiIiIrJCd3AUQERERGcOgQkRERFaLQYWIiIisFoMKERERWS0GFSIiIrJaDCpERERktRhUiIiIyGo5yF1AXeh0Oly/fh3u7u5QKBRyl0NEREQmEEIgJycHQUFBsLOreszEpoPK9evXERISIncZREREVAspKSkIDg6uso1NBxV3d3cApQfq4eEhczVERERkCrVajZCQEOnveFVsOqiUne7x8PBgUCEiIrIxpkzb4GRaIiIisloMKkRERGS1GFSIiIjIajGoEBERkdViUCEiIiKrxaBCREREVotBhYiIiKwWgwoRERFZLQYVIiIisloMKkRERGS1GFSIiIjIajGoEBERkdViUDGioEgLIYTcZRARETVqDCoGnE/PQYd3N+OtDfFyl0JERNSoMagYsGxXIgBg7dEUmSshIiJq3BhUDMhQa+QugYiIiMCgYtAjPYPlLoGIiIggc1DRarWYN28eWrVqBWdnZ7Rp0wYLFy6UfRJrsLcLAKC1r6usdRARETV2DnJ++Mcff4zly5dj9erVCA8Px/HjxzF16lR4enpi+vTpstVlb6cAAJToeNUPERGRnGQNKgcPHsS4ceMwZswYAEDLli2xdu1aHD161GB7jUYDjebu/BG1Wl0vdTncCSpaBhUiIiJZyXrqp2/fvtixYwfOnz8PADh58iT279+PUaNGGWwfFRUFT09P6RESElIvdd0dUdHVS/9ERERkGllHVObMmQO1Wo2wsDDY29tDq9Xiww8/xKRJkwy2nzt3LmbOnCm9VqvV9RJWHOzvBBUtR1SIiIjkJGtQ+emnn/Djjz9izZo1CA8PR1xcHGbMmIGgoCBMnjy5UnuVSgWVSlXvdTlwjgoREZFVkDWozJ49G3PmzMFjjz0GAIiIiMCVK1cQFRVlMKhYir1d6RkxzlEhIiKSl6xzVPLz82Fnp1+Cvb09dDLPDXHgHBUiIiKrIOuIytixY/Hhhx+iefPmCA8PR2xsLL744gs8/fTTcpYlTabliAoREZG8ZA0qS5Yswbx58/DSSy8hIyMDQUFB+L//+z+8++67cpZ1dzItgwoREZGsZA0q7u7uWLx4MRYvXixnGZU43DkdJQSg0wnY3RlhISIiIsvivX4MsC8XTDiqQkREJB8GFQMcygUVzlMhIiKSD4OKAfojKrzyh4iISC4MKgZwRIWIiMg6MKgYwDkqRERE1oFBxQCFQsG1VIiIiKwAg4oR9rzfDxERkewYVIwom6ei5R2UiYiIZMOgYoQ97/dDREQkOwYVIxw4R4WIiEh2DCpGcI4KERGR/BhUjFAoSoOKTjCoEBERyYVBxYiypVSYU4iIiOTDoGKEAqVJhUGFiIhIPgwqRkgjKmBSISIikguDihF356jIXAgREVEjxqBixJ2cwsm0REREMmJQMcJOwTkqREREcmNQMUIhXfXDpEJERCQXBhUjpBEVmesgIiJqzBhUjJDmqHA2LRERkWwYVIy4k1N41Q8REZGMGFSMuHvqh0mFiIhILgwqRii4hD4REZHsGFSM4OXJRERE8mNQMYJ3TyYiIpIfg4oRdyfTMqgQERHJhUHFCLs7/2UYU4iIiOTDoGKEAmVzVBhViIiI5MKgYoQdr/ohIiKSnaxBpWXLllAoFJUe06ZNk7MsAOUn08pcCBERUSPmIOeHHzt2
DFqtVnp9+vRpDBs2DI888oiMVZWSltDnkAoREZFsZA0qfn5+eq8XLVqENm3aYODAgQbbazQaaDQa6bVara632riOChERkfysZo5KUVER/vvf/+Lpp5+WTrtUFBUVBU9PT+kREhJSb/WUVcDJtERERPKxmqCyceNGZGVlYcqUKUbbzJ07F9nZ2dIjJSWl3uq5e68fIiIikousp37KW7FiBUaNGoWgoCCjbVQqFVQqlUXq4RwVIiIi+VlFULly5Qq2b9+O9evXy12K5G5QkbcOIiKixswqTv2sXLkS/v7+GDNmjNylSO5OpmVSISIikovsQUWn02HlypWYPHkyHBysYoAHwN0RFeYUIiIi+cgeVLZv347k5GQ8/fTTcpei5+5kWiYVIiIiucg+hDF8+HCrPL0irUyrk7kQIiKiRkz2ERVrVbaOCq/6ISIikg+DihHSTQnlLYOIiKhRY1AxQsGrfoiIiGTHoGKEHa/6ISIikh2DihHSZFoGFSIiItkwqBjBybRERETyY1AxgjclJCIikh+DihHaOyMpqw4kyVwJERFR48WgYsS2s+kAgIs38mSuhIiIqPFiUCEiIiKrxaBCREREVotBhYiIiKwWgwoRERFZLQYVIiIisloMKib46K9zcpdARETUKDGomODfey/JXQIREVGjxKBixNguQXKXQERE1OgxqBgxbXAbuUsgIiJq9BhUjHC0538aIiIiufGvsRGezo5yl0BERNToMagY4eum0nt9K1cjUyVERESNF4OKiXp8sF3uEoiIiBodBpUauHQjV+4SiIiIGhUGlRq47/M90JRo5S6DiIio0WBQqaH/HUuRuwQiIqJGg0GlCoseiqi0LTu/WIZKiIiIGicGlSr0bOktdwlERESNGoNKFTycKq+l8vm288jMK5KhGiIiosaHQaUKrioHg9uX77lo4UqIiIgaJwaVKjg72hvczrspExERWYbsQeXatWt44okn0KRJEzg7OyMiIgLHjx+XuywAgJ2dQu4SiIiIGjXD5zYsJDMzE/369cPgwYOxadMm+Pn54cKFC/D25iRWIiIikjmofPzxxwgJCcHKlSulba1atZKxosr+Nak7XvrxhMF92fnF8HThzQuJiIjqi6ynfn7//Xf07NkTjzzyCPz9/dGtWzd8++23RttrNBqo1Wq9R33LLjC8bkrLOdHo8v5WbIpPrfcaiIiIGitZg8qlS5ewfPlytG3bFlu2bMGLL76I6dOnY/Xq1QbbR0VFwdPTU3qEhITUe42bTqdVuf9FI6MtREREVHcKIYSQ68OVSiV69uyJgwcPStumT5+OY8eO4dChQ5XaazQaaDQa6bVarUZISAiys7Ph4eFRLzVGbTqHb/ZUfZXP5UVj6uWziYiIGiK1Wg1PT0+T/n7LOqLStGlTdOzYUW9bhw4dkJycbLC9SqWCh4eH3qO+dQvxqvF7hBCQMf8RERE1GLIGlX79+iEhIUFv2/nz59GiRQuZKqrsWlZhjd/z9Kpj6PnBdhRrdfVQERERUeMha1B57bXXcPjwYXz00UdITEzEmjVr8O9//xvTpk2Tsyw9poyM5GlK9F7vSriBW3lFWMc7LRMREdWJrEGlV69e2LBhA9auXYtOnTph4cKFWLx4MSZNmiRnWXqaeTlX2+bd385Iz8sHm99ir9VLTURERI2FrOuoAMD999+P+++/X+4yjDJlpsmvJ67CTWWPBeM6QVfuDZylQkREVDeyL6Fv7dydTMtyqw9dAQCU6DgvhYiIyFwYVKrRP9QXDjW450+J9u44SsyVTNzM1VTRmoiIiKrCoFINhUKBlVN7mdRWCIESnf4Jn7nr4+ujLCIiokaBQcUE/UN9MSTMv9p2p6+poa0QVLadTcfrP52sr9KIiIgaNAYVEygUCjzdv/qbJeYVlaCwWFtp+68nrtZHWURERA0eg4qJFCZMUyko0mLtUcOr6gJArsZwkCEiIiLDZL882VbYm5BUZv4Uh8x8w3db/i3uGl5d
Fwd3JwfEzx9h7vKIiIgaJI6omEhhQlAxFlIA4NV1cQCAnMISo22IiIhIH4OKibILjIeQmuINC4mIiEzDoGKi61kFZuvrSNJts/VFRETUkDGomMjXTWW2vnJ5+oeIiMgkDComMuWqH1PpeOqHiIjIJAwqJqrBKvrV0jGnEBERmYRBxWR3k8qwjgF16okjKkRERKZhUDFR52BP6fm2s+l16qviMvtERERkGIOKiYK8nLFr1iDEvDO0zn1xRIWIiMg0XJm2Blr5upqlH+YUIiIi03BERQYCTCpERESmYFCRAUdUiIiITMOgIgMGFSIiItMwqBAREZHVYlCRgZeLo9wlEBER2QQGFRmoHOzlLoGIiMgmMKjIgOuoEBERmYZBRQYMKkRERKZhUJGBurBE7hKIiIhsAoOKDKavjcXtvCK5yyAiIrJ6DCoy+WbvRblLICIisnoMKrXw3L2t6txHUYnODJUQERE1bAwqtTBjaDs80iMYK6f2qnUfOh0n1BIREVVH1qAyf/58KBQKvUdYWJicJZnEVeWATx/pgsHt/eUuhYiIqEFzkLuA8PBwbN++XXrt4CB7SRahUCjkLoGIiMjqyX7qx8HBAYGBgdLD19dX7pLMYuGDneQugYiIyObJHlQuXLiAoKAgtG7dGpMmTUJycrLRthqNBmq1Wu9hrSKaecpdAhERkc2TNaj06dMHq1atwubNm7F8+XIkJSXh3nvvRU5OjsH2UVFR8PT0lB4hISEWrriy2SPao1Mzj0rbwwLdq3xfVj7XUSEiIqqOQgjrWc89KysLLVq0wBdffIFnnnmm0n6NRgONRiO9VqvVCAkJQXZ2Njw8KocFS2o5J1p6fmbBCLiqHPS2VTQkzB8rptT+qiEiIiJbpVar4enpadLfb6uauerl5YV27dohMTHR4H6VSgWVSmXhqmrOVVX9f1arSYdERERWTPY5KuXl5ubi4sWLaNq0qdyl1DsrGsgiIiKyWrIGlVmzZmHPnj24fPkyDh48iPHjx8Pe3h4TJ06UsyyL4HpvRERE1ZP11M/Vq1cxceJE3Lp1C35+fujfvz8OHz4MPz8/OcuyiFNXs+QugYiIyOrJGlTWrVsn58fXuw0v9cX4fx00uE9dWGLhaoiIiGyPVc1RsWXeLo6VtnVr7o1FD0UYbK/luR8iIqJqMaiYibOjvcHtj/VubnD7mM5NeWNCIiKiajComMk3T/ZEyyYu+ObJHia1jz6VijFL9jOsEBERVYFBxUwigj2xe/ZgjAgPNPk951LVuJWnv0Jtdn4xHl5+ECv2J5m7RCIiIpvDoCKz5Nt5eq/XHE1GzJVMLPzzrMH2Wp3Ao18fwuTvj+K3uGsoLNZaokwiIiJZMKjI7O80/fsa3crVGGlZ1l6No5dvY8/5G3h1XRwWbfq7PssjIiKSFYOKzIpKdHV6/5+nrpupEiIiIuvDoCKzLiFeNWpfceX9m7lFSFcXmq8gIiIiK8KgYgGP9AgGALx7f8dK++wUiirfuyk+FSMX78WF9ByjbcYvO1C3AomIiKwUg4oFfPpIF1xeNAb9Qn0r7dPqqj718+KPJ/B3Wg5eXRcHoPKICgBcz+aIChERNUwMKhbUPtAdvm5KvW3FWtPWUcnVlC65L8B1V4iIqPFgULGwoR0C9F6bupS+ztBQChERUQPHoGJhZSMjZUoqBBVjU1bKcgrzChERNSYMKhZWrNWfk1KirdnlycwpRETUmDCoWFjFXFKiE1ixPwm/nyxdD6X8iMnQL/ZIz8tO/QgOqRARUSPiIHcBjU3FuSbn03Lw+bbzAIAHugThu3L3+EnMyJWel72t4qkiIiKihowjKhZW8dTPzXJL5ld13540dSFu5xXh5+Mp9VYbERGRteGIioVVvMqn/MvdCRlVvrf7wm31URIREZHV4oiKhVU89bPmaLL0PKewpGJzIiKiRo1BxcIqjqiUfz37l1OWLoeIiMiqMahYGCfDEhERmY5BxcLyNcYnzNaFVifw5fYL
OHjxZr30T0REJIdaBZXVq1cjOjpaev3GG2/Ay8sLffv2xZUrV8xWXEOUUMVdkOti/u9n8M/t5/H4t0fqpX8iIiI51CqofPTRR3B2dgYAHDp0CMuWLcMnn3wCX19fvPbaa2YtkEzzn8MMiERE1PDU6vLklJQUhIaGAgA2btyIhx9+GM8//zz69euHQYMGmbM+IiIiasRqNaLi5uaGW7duAQC2bt2KYcOGAQCcnJxQUFBgvuqIiIioUavViMqwYcPw7LPPolu3bjh//jxGjx4NADhz5gxatmxpzvqIiIioEavViMqyZcsQGRmJGzdu4Ndff0WTJk0AADExMZg4caJZCyQiIqLGq1YjKl5eXli6dGml7QsWLKhzQURERERlajWisnnzZuzfv196vWzZMnTt2hWPP/44MjMzzVZcQxTRzFPuEoiIiGxGrYLK7NmzoVarAQDx8fF4/fXXMXr0aCQlJWHmzJlmLbChqXivHyIiIjKuVkElKSkJHTt2BAD8+uuvuP/++/HRRx9h2bJl2LRpU60KWbRoERQKBWbMmFGr99uK5we0lrsEIiIim1GroKJUKpGfnw8A2L59O4YPHw4A8PHxkUZaauLYsWP45ptv0Llz59qUY1PGdW2GxRO6yl0GERGRTahVUOnfvz9mzpyJhQsX4ujRoxgzZgwA4Pz58wgODq5RX7m5uZg0aRK+/fZbeHt7V9lWo9FArVbrPWzRg92ayV0CERGRTahVUFm6dCkcHBzwyy+/YPny5WjWrPQP76ZNmzBy5Mga9TVt2jSMGTMGQ4cOrbZtVFQUPD09pUdISEhtyrcKSgfeD5KIiKg6CiHkm925bt06fPjhhzh27BicnJwwaNAgdO3aFYsXLzbYXqPRQKPRSK/VajVCQkKQnZ0NDw8PC1VtHjFXMvHw8oP10vflRWPqpV8iIiJzUKvV8PT0NOnvd63WUQEArVaLjRs34ty5cwCA8PBwPPDAA7C3tzfp/SkpKXj11Vexbds2ODk5mfQelUoFlUpV25KtCi9TJiIiql6tgkpiYiJGjx6Na9euoX379gBKT8uEhIQgOjoabdq0qbaPmJgYZGRkoHv37tI2rVaLvXv3YunSpdBoNCaHHltkp5C7AiIiIutXq6Ayffp0tGnTBocPH4aPjw8A4NatW3jiiScwffp0REdHV9vHkCFDEB8fr7dt6tSpCAsLw5tvvtmgQwoAONhzjgoREVF1ahVU9uzZoxdSAKBJkyZYtGgR+vXrZ1If7u7u6NSpk942V1dXNGnSpNJ2ko8QAs+sPg4fVyU+e6SL3OUQEVEjU6t/1qtUKuTk5FTanpubC6VSWeeiyHokpOdg598Z+CXmqtylEBFRI1SrEZX7778fzz//PFasWIHevXsDAI4cOYIXXngBDzzwQK2L2b17d63fS/WjRHv3ojAhBBQKTq4hIiLLqdWIyldffYU2bdogMjISTk5OcHJyQt++fREaGmr00mKqrI2fKwDATVXri68sircpIiIiS6vVX0gvLy/89ttvSExMlC5P7tChA0JDQ81aXEO3ZcYAFJboMHLxXuRqSqTtjvYKFN8ZyZjStyVWHbwsU4X6dELADhxRISIiyzE5qFR3V+Rdu3ZJz7/44ovaV9SIONjbwc3A1T8x84bh15ir6NvGF+0D3bHmaDKKSnQyVKhPK0TtF94hIiKqBZP/7sTGxprUjnMYaq7ifzIPJ0dM7ddKer1j5kB8siUBf5y8blJ/5pxLUr4bnfxZiYiIGhmTg0r5EROyrBAfF/yjR7DBoOLt4ojM/GK9bZoSHZwczbMOjaLcqR4tJ6kQEZGFcdUxK7Nqai+D2x2NLGXr5265WwpotQwqRERkWQwqVmZQe3+D28uvZNs5+O59ghR1nNy641w67vtsN+JSsgzuL3/qhyMqRERkaQwqVsCUsFE+MLw2rB1eH9YOq5/ujZTM/Epta5Innll9HJdu5uHZ1ceqbavVMagQEZFlMahYgZfvK72s+4Eu
QUbbKCo8f2VIWwxs54f8Im2ltkXams96vZlbhF1/Z1TariuXenQcUSEiIgtjULECj/YMwZ7Zg7B4QlejbWpyEU+XBVvx9ob46htWMHVV5VGV8tmEIypERGRpDCpWokUTV9gZmTAL6AcGU+LCj0eSq9yfcjsfj3x9sNp+yocTBhUiIrI0rt9lI5p6OUvPOwV5VtHSNJO+O4Lk25Xnt1TEUz9ERCQnBhUb0czLGcse7w4He4XeJckfjY/AWzU4zbP9bDrScwpNCikAUH4QpYQjKkREZGE89WNDxnRuihHhgXrbHusVgg5NPTC+W7NK7f9z+ArUhcUoLje59tkfjuPtDaeNfsaTK44gu9wCcuVHUeKvZtelfCIiohpjULFxdnYKbHr1Xnz2SJdK++ZtPI3O87eix8JtJve378JN9Pt4p/RaV24UZfu59LoVS0REVEMMKg1EFfNwoS4swZxfT5ncV66mBGuOJOP+JfuQnqORtpdwZVoiIrIwBpUGorqbEK47llKj/t7aEI/T19T4bEuCtK2EdyUkIiILY1ChKhUU311QjpNpiYjI0hhUGpDqFoU7c73mk2HLd/lQ9+Aav5+IiKguGFQakOqWOblRbr6JqTLKvcddxavZiYjIshhUGpH/Hr5Sp/dzwTciIrI0BpVGZPu5yjcdrAkuoU9ERJbGoEImY0whIiJLY1Ahkwme+iEiIgtjUCGT8cwPERFZGoMKmYyTaYmIyNIYVMhkHFEhIiJLY1Ahk3GOChERWRqDCpmMp36IiMjSGFTIZLwnIRERWZqsQWX58uXo3LkzPDw84OHhgcjISGzatEnOkqgKHFEhIiJLkzWoBAcHY9GiRYiJicHx48dx3333Ydy4cThz5oycZZERzClERGRpst5lbuzYsXqvP/zwQyxfvhyHDx9GeHi4TFWRMRxRISIiS7Oa2+FqtVr8/PPPyMvLQ2RkpME2Go0GGs3du/mq1WpLlUfg5clERGR5sk+mjY+Ph5ubG1QqFV544QVs2LABHTt2NNg2KioKnp6e0iMkJMTC1TZuHFEhIiJLkz2otG/fHnFxcThy5AhefPFFTJ48GWfPnjXYdu7cucjOzpYeKSkpFq62ceM6KkREZGmyn/pRKpUIDQ0FAPTo0QPHjh3Dl19+iW+++aZSW5VKBZVKZekS6Q6e+iEiIkuTfUSlIp1OpzcPhawHT/0QEZGlyTqiMnfuXIwaNQrNmzdHTk4O1qxZg927d2PLli1ylkVGcESFiIgsTdagkpGRgaeeegqpqanw9PRE586dsWXLFgwbNkzOssiIdHWhWfopKNLCWWlvlr6IiKhhkzWorFixQs6Ppxr68fAVvDW6Q536WH3wMt77/Qy+mtgND3QJMlNlRETUUFndHBUyn64hXnAxMHLRqZlHrfp7rHfzupaE934vXXV4+trYOvdFREQNH4NKAxI7T/+U2YNdgxA/fwQ6NNUPJr+80BcKReX3V2xXkaezY53q05Ro9V6/sjaWlzwTEVGVGFQaEG9XZaVt9nYKBHroX9Lt5GiPFwa2qdS2la9Llf3X9aqffyw/pPf6j5PXcSOXV3gREZFxDCqN1JiIppW2dQ72wv8NbA0AmBzZotL+ml71o9UJFGt1+Hjz33h4+UHEX8uu1EZTrKtZp0RE1KjIvuAbmdfMYe3wxbbzRve/OqQtAKBTM0/8/nI/BHo64VZuEfaev4Gp/VrB0V6BJ+9pASGA1Yeu6L1XV4OksishA1NXHqu2naaEQYWIiIzjiEoDM/1OEAGAsljx5qgwKBTA0A7+mDH07v7OwV7wd3dCh6Ye+L+BbaB0sINCoUCwt4vBOSw1OfVjSkjRr5KIiKgyjqg0AmGBHjj/wSg42pueSw1lkvUnrqFjkAfu72y+y4rPpeYg1N/dbP0REVHDwhGVRqImIQWA3mXND3VrBgBIUxfi5TWxKLpzuqZYq0PLOdF4/aeTta7rlbWxSLqZV+v3ExFRw8agQgY1cVPhmyd74D/P
9IZfhauGyk4BffTXOQDAryeu1umzNsZeq9P7iYio4eKpnwasrkuUjAgPBAAcvHjL4P4TyVl1+4A7eIkyEREZwxGVBsyjjgu0lTmZkmVwe4n27hU7X26/gMOXbtVqATc7AxN3iYiIAI6oNEhRD0Xg0MVbGNfVPJNez6aq9V6XZZHyc0v+ub30kuja3MPHwY55mYiIDONfiAZoYu/m+GpitxpPoDUmK79Y77W4c0lxfpG2Utvf467XuP8rtziZloiIDGNQoRqb+O/DRvcl36556NiVcKMu5RARUQPGoEI1dvJqdqUbDJY5n56LbWfTLVwRERE1VAwqVK1uzb0qbeu3aJfR9s/9cLweqyEiosaEQYWq5aaqPOf6Ji8pJiIiC2BQoWrV5B4/RERE5sSgQtXS8QbHREQkEwYVqpbgHY6JiEgmDCpERERktRhUqFoKcI17IiKSB4MKVYunfoiISC4MKlQtXvRDRERyYVChajGnEBGRXBhUqFoBHk5yl0BERI0UgwpV6937O9b7Z+h0HLchIqLKGFSoWn7uqnrp950xHaTnXP2WiIgMYVAh2TzaK0R6rmVQISIiAxhUSBajIwJhr7i7Pos5c4oQAkk38yAYfoiIbJ6sQSUqKgq9evWCu7s7/P398eCDDyIhIUHOkshCOgd7wa5cUNGacY7KU98fxeDPduPLHRfM1icREclD1qCyZ88eTJs2DYcPH8a2bdtQXFyM4cOHIy8vT86yqAYufjQa97T2Qf9QXyye0BU/PttHb3/Hph6V3jMiPABT+raEXbmfvtrOUSk/alKs1eHRrw9h34WbAIDF2xlUiIhsnYOcH75582a916tWrYK/vz9iYmIwYMAAmaqimrC3U2Dd85HS6wvpOXr7Hez1l99XKIBvnuwJoDRYlDH1Ds3LdiWiiasSj/VujqU7L+CHQ1ew/qW+CPZ2wR8nr+Po5du1PBIiIrJGVjVHJTs7GwDg4+NjcL9Go4FardZ7kGU81L2ZSe3Knc2Bi9Ie/uWuGHJ3csAPT/eWXpefo1JxRCUxIxf7LtyotO3TLQmYsz4eAPDZ1vPIyNHgk82lpwvzirSmHQwREdkMqwkqOp0OM2bMQL9+/dCpUyeDbaKiouDp6Sk9QkJCDLYj83t/nOHvpCL7cudzts8cCHu7u2Hk5LvDcW9bP+l1+VBT8aqfoV/swZMrjuLM9Wxpm7qw2OBn/n7yeqVQQ0REDYPVBJVp06bh9OnTWLdundE2c+fORXZ2tvRISUmxYIWNm5vKATHvDK22XcsmLhjYzg/3hfkjyMsZEc08pX12dhVPAylQtsnYgm9/p+YY3F7RkyuO4peYqya1JSIi2yHrHJUyL7/8Mv7880/s3bsXwcHBRtupVCqoVPWz+BhVr4mbCh2aeuBcqvFTbgqFAqvLnd55bkBrCAEMDvM32N5OoYBOCJhy0U/5QRdDlx6fTMmqvhMiIrIpsgYVIQReeeUVbNiwAbt370arVq3kLIdMUD6kmLJircrBHq8MaWt0v52dAtAJowu+lb8yqPzk22It10ghImoMZA0q06ZNw5o1a/Dbb7/B3d0daWlpAABPT084OzvLWRqZYHJkizr3UTahVqcTKCjSYvOZVAxspz/6UqLV4anvj+LgxVvStvyikjp/NhERWT9Zg8ry5csBAIMGDdLbvnLlSkyZMsXyBVGNvDgotM59SHNUhMCHf53Ffw8n6+1XQIEjSbf1QgoApKkL6/zZRERk/WQ/9UO2ZVzXIPwWdx0A9K7oqa2yCbY6AUSfSq20/+eYFBxIvFVp+8jF+0zqXwgBhaLudRIRkTys5qofsg1uKvNm27Jl9AuLtTAUWw2FlJpIuslVjomIbBmDCtWIOUZRyssuKF0bZdSX+5CVb3idlLpYsjPR7H0SEZHlMKhQjdjZ2GmUDbHX5C6BiIjqgEGFasTcQcWUS5yJiKjxYlChGgn1dzNrf0p7/ggSEZFxVrEyLdmOCb1CkJZdgHtaNzFLf3bM
KUREVAUGFaoRezsFZg5vb77+bGzOCxERWRb/PUuyqnijQiIiovIYVEhWHFEhIqKqMKiQrMy9LgsRETUsDCokK2tY3v7IpVv46XiK3GUQEZEBnExLsrKGq5Mn/PswAKCtvxu6NfeWuRoiIirPCv5MUGOWX6SVuwTJtawCuUsgIqIKGFRIVpdu1M9NA12U9gCA8CCPKttpdXdvhejA+TJERFaHQYVs3sJx4ZW2vXxfKACgQ9Oqg8pb6+Ol53+cTNXbdy2rADmF5r9RIhERmY5BhWzamuf64MnIltj5+kC97WWjIzohDL1N8r9yk2ij4+8GldTsAvRbtBMR87easVoiIqopBhWyaX3b+AIAWvu5Ye1z96Ctvxti5w2Tbp5YVU5JVxdW2vbdvksAgJUHLuttLyrR4eDFmygstp45NUREjQGDClm9DS/1Nbj95HvD9V5HtmmCbTMHwttVKV32XNWIyqmr2ZW2fRB9DqevZePfey/d/ZyULCz44wwe//YIwuZtxv4LN2tzGEREVAsMKmT1uoZ4Gdzu6exo9D1l82J1VYyoGJs8m5lfpPf6dl4RfjySLL1+YsUR450SEZFZMaiQ1avNonB2JoyoGOv2yRVH9V5PXXWsxp9PRETmwaBCspo7Kkx6PqVvS7P1WzZYIqoIKnZ1WBW35wfbav1eIiIyHVemJVk9d29rdG/hDaW9HewUCqw6eFlv/9AO/nqve7TwRsyVTHQxcjqojDRHRQcUa3VIuZ2Pr3ZcwLTBoWgb4A6gbkHlZm5R9Y2IiKjOGFRIVnZ2CvRq6QMA0OkEhnbwR2GxDh+Nj8CWM2l4pGcwAGBohwBsP5eO+WPD0byJC9xUVf/oloWQpJt5aPv2Jmn73gs3cWLesDufXR9HRERE5sSgQlbDzk6B7yb3kl4/N6C19PybJ3sgK78ITdxUpvV1Z7AkIT1Hb/vtvCJk5xfD08WxTiMqQOlpJWu4qSIRUUPGoEI2wd5OYXJIAao+rbNo8zn0aOEDN5V9nWoq0uqgcqhbH0REVDUGFWqQqhroWHs0BWuPphhvYKKiEgYVIqL6xrP01CDV9bSOKXS6ev8IIqJGj0GFGiRLTJQtYVIhIqp3DCrUIClQ/yMq2mpueEhERHXHoEINkiUuxtFWtT4/ERGZBYMKNUiG7oxsbgwqRET1T9agsnfvXowdOxZBQUFQKBTYuHGjnOVQA6Iprv/5I+YKKjqdwOdbE7D3/A2z9EdE1JDIGlTy8vLQpUsXLFu2TM4yqAHKL9bW+2eYK6hsjLuGJTsT8dT3R6tvTETUyMi6jsqoUaMwatQok9trNBpoNBrptVqtro+yqAEoKDJPUOnZwhvHr2Qa3GeuoLLlTJpZ+iEiaohsao5KVFQUPD09pUdISIjcJZGVyi8qMUs/Xzza1ei+934/g4eXH0RRifHTTFqdgK6aQOOqvPvvhWItL3kmIirPpoLK3LlzkZ2dLT1SUuq+uig1TPlmGlFp3sTF6L6DF28h5komdidkSNvOXlcj/mo2gNJ7AU345hCGL94LTUnlenQ6gVFf7sP62GvSttxC8wQsIqKGwqaCikqlgoeHh96DyJCaLnFy8r3heGt0GP5eOFLatmRiN702rf1cDb43Oj4VQOloyOiv9mHs0v3I1ZRAJ4DjVzKRmJGLuOQspKsL8dmWBKgLiwEAcVezcC5V//RlCa8kIiLSw3v9UIPUrbmXFCBM4ensiOcHtAEAPNO/FRIzcjE6oqlem2Zezrh0I6/SezefLp1jUlBuAm9uYQnOXr8bQib8+7D0fOmuRDx5Twv85/CVSn3xkmciIn02NaJCZCp7u9qv+Dbv/o5Y/XRvqQ/fO3dtHtTe32B7zZ05Klrt3ZCRri7Eo98cMvoZhkIKAHy771KtaiYiaqhkHVHJzc1FYmKi9DopKQlxcXHw8fFB8+bNZayMbJ05b0q4eca9OHElE/eF+WPhn2eNtisud++f+GvZtfqsFfuTEOztjKn9WtXq/URE
DY2sIyrHjx9Ht27d0K1b6VyAmTNnolu3bnj33XflLIsagPvCDI9+GBL1UESV+33dVBgeHggH+6p/XbafvTup9ott503+/IoW/GE8DBERNTayjqgMGjQIgjd2o3oQ4mP8ap3yLi8aY7bPfGtDvPT8dl6R2folImrMOEeFGo1Qfzd8P6VnnfowNvflpR9j6tQvEREZxqBCjcK+Nwbjz1f6476wAOx7Y3Ct+znxzjB4OjtW2v5XPFeXJSKqD7w8mRqF8qeCQnxcsO21AfB0qRw4quPp4og2fq44kZxlxurqV4lWh9C3N8HBToHEj0bLXQ4RUY1wRIUapbYB7vB3d6rVe+ty6bOl/J2mxvoTVyGEwJqjyQBKF5PjnDAisjUcUaEGS+VgB02JDq8Pa2fWfuUMKqevZWP2L6cwZ1QYBrbzQ66mBPM2nsaYiKYY2jFAajdy8T4AwKUbeXByvPvvkSKtDioHe4vXTURUWxxRoQbrzIIR2DVrEF4Z0tas/R6+dNtsfW2fOaBG7aesPIZzqWpM/v4oAGDpzkRsiL2GZ384jg/+PIsbORq99kt3JcKj3JwadQHvJUREtoUjKtRgOdjboZWv4fvzyO21oe0Q7O2MUH/3Svvuae1j9H03c/WDyNXMfOn5d/uT8N3+JPzyQqRem30XbkrPM3IK4eeuqm3ZREQWx6BCZGFbXxuAdgGVA0oZL2cl8jQlcLBXQOVgD02JFioHe+gM3Afoz1OV72f0j6/1l+7fdjZdep6aVYjwIM86VE9EZFkMKkQW5uOqNLhdaW+HIq0O2QXFCH9vCwBgcmQLrD50Bf/oEYwpfVvqtTcUXKrjbeSziYisFeeoEFnAy4ND4eumwrrn75FuclimX2gTAMAT97QAABy6dEvat/pQ6c0Lf4m5qnd3ZgCVXpvCTWW+f5v8FncNkVE7EJucabY+iYgq4ogKkQXMGtEes0a0N7hv5ZTeyMov0ptLYsgjFU7plI261ESxVld9IxO9ui4OAPDGL6ewbeZAs/VLRFQeR1SIzGBK35Y4MOc+g/sqTm6tSOlgB38PJ4tc9lxkxqBS5hbva0RE9YgjKkR18NKgNsjML8b8B8Ir7fvnhC4YExEEpYNp/x6wRFAp0dZ9wbcL6Tn47+Er0uvbeUXYd+EG7m3rV+e+iYgq4ogKUQ0N7eAPAHB3csAbI8MQ9VCEtO/Y20MBAH1a+WB8t2CTQwqAGrWtLVNP/QghcN/nu9FyTjSy8u+OmPx56jqG/XOvNHemzJMrjpq1TiKiMhxRIaqhT/7RBSv2X8I/eoRU2ufnrsLlRWNq1a+jvXlHVAa084OLoz3aBbjhq52JAEw79bPz73Q8veq49PqBpQew943ByM4vxstrYs1aIxFRdRhUiGrIx1WJ2SPCzN5vQZF5549cvZ2PnbMGAQCi41Nx8UYeNMU66HQCdkZOM5VodXohBQCSb+fj1NUsPLD0gFnrIyIyBYMKkZXIKzLv8vbu5ZbOb+KqwsUbeXjhvzEIC3SH0sEOp65m44+X+yMi2BOFxVrsv3AT7QMNL0THkEJEcmFQIbIS3Zt71/g9rf1cUVCkRWp2YaV975eb4Hv08t37E/2dliM9H7t0Py59NBrP/XC82sujqzIkzL/W761oV0IGDibexJsjw+Bgz2l0RI0dgwqRlQj1d6vxe9Y8ew8CPZ1wM1cDT2dHCAGkqwuRqylBh6YeJvXR+q2/avy5FdX1sudTV7Ow7Ww6pg0OxdSVxwAA3+5LqvV8HyJqOPjPFSIr8q9J3Q1uf7xPc4PbnRxLf4V93VRwtLeD0sEOIT4uJocUc6nLQnK5mhI8sPQAluxMxKJNf+vte/OXU7W6VQARNRwMKkRWZHB7w6dQPhofgYe7B2NgO/21Srxc5Lt3zyf/6Iyvn+gBACgquRtUsvKLcOzybQhRfcC4fDMPncqtsLvq4GW9/f87noJfT1ytcW1CiEqfn51fjG1n0826Oi8R1T+e+iGyIs5Ke+n5Bw92
wsbYa3hlSFsAwOePdgEAFBRp8dXOCxjbOUiWGgHgwJz70MzLGbv+zgAAnEjOwvS1sfhqYjd0W7gNQgBLJnZD3zZN4OWi1FvM7odDl/H9/iS8Prw9Xllb/eXO2QXFNaotI6cQvT/cgU7NPPDHy/2hUCiQU1iMLu9vBVB6b6UDibewYnJPDOkQUKO+icjyFMKUf/ZYKbVaDU9PT2RnZ8PDw7JD3UT1JTOvCJoSHQI9nczW55iv9uHMdXWd+4l7dxgUUMDTpfSKov0XbuKJFUek/UPC/LHjTngpb/GErvBxVeKp72u3MFxVc1Xir2YjNiUT+UVa/BpzFRcycvX2R0/vjzFf7a9xv0RUf2ry95sjKkRWxtvV/KdzFj7YCQ/966DJ7V2V9sgrqnx35oqnmiqupmsopADAjP/FmfzZpriQnoMQHxc4Odpj7FLDIaSMsZBiDpviU/FB9Dn88Up/HEi8iTZ+bmjl66o3MkZEdcOgQtQImHrp86hOgRACWPRwBK5nFWL0V/vg767Ch+Mj4FluXZYylrp6+EaOBn7uKgDAxthrUvB5KrKF2T5DCAGFovJCeFqdwO28Igz5fDdeG9YOFzJyMWt4e/i4KvHijycAAN0XbtN7j7lHarLzizH/jzN4qHuzau+ptGJ/Ega190Mbv5pfRUZkjRhUiBqpI28NwbOrjyP+2t2F38rzclFW+wc3/mp2fZYo6fXhdpxeMALRp67jzV/jpe0/VLjnUE2oyo0GXc3MR/+PdwEoDT/vj+uE/KISzP75FKLjU6V2C/44CwBYcyQZozoFGu27RKur1RowhsLS0aTbePSbQwCADbHXqvxOXvxvDDadTsPCP00LS0IIREbtRJq6EJc+Gm10xeL6VHbLhi0zBhhdcJAaNwYVokboP8/0RoCHE/54pX+t/6gCQK9WPmapx13lgA/Gd8Kr6+KMtil/dZA5BHk5QwiBDbHXMPOnk9L2Hw5dwcbYa1AXVr1S8KbTaUb3xaVkoWfLmv23aTknukbtDSlf057zN/Bb7DUEejrhjZGVb/mg0wnkaEqQpi5dLHDxjguYOaxdrT63LGAJIZBXpIWbygHXsgqw5XQaHu4eLM1pMqTslg0jFu/VC1cX0nMw7J978eH4TpjUp4X0OYXFOuy7cAOD2vtb5EaeDV2JVoecwpJ6OeVsLpxMS9RIbD+bjmd/OI63Rofh+QFtzNZvXf/Alv/jVKzV4f0/zmJ0RFP4e6gw5PM9dS1PT48W3oi5kmnWPg0J9XfD9pkD9bZpSrR49OtDSFMXoq2/O/7zTG8oFAqcvpaNq5kFeOG/MSb1PalPc/Ru5YOeLX1QXKJDS19XvPjfGBRrBbafSzf4Hl83FQ7NvQ8JaTm4f4nhOTvDOgbg26d6VtqeXVCMLgu24vspPXFf2N2rpHI1JZjz6yn8eSq10nsqSooajauZBQj2dtYbMRJCoNVc/QUHh3cMwL+f6qn3c3V50RikZRfinqgdem13vD6w0ikuIQR2n7+BTkGe8HZx1AvhxVodHO3tsO/CDXi7KBHq74a07NIFEgM8nODnrqryXli1dT2rAMevZOKBLjW7Uu/KrTwEe7voXTVnbmX/nb97qieGdrTcVXA1+fvNoELUiOQXlcBFad6B1Cu38jD0iz0o1t79X8m/n+yB4eGlp0ZOJGfimz0XEZuchYwcDWaPaI8R4QEY+sVePNu/Fd65v6PRvmsbghI/HIWsgmI0cVUiR1MCD6e7/6K/dCMX95k5ABkyoJ0fnu7XEptPp2HdsZR6/7y6erh7MBQKYM6oMGh1Ah9Gn8P8B8Irzb8BgOPvDEXPD7ab3PfcUWGI2vQ3pg1ug1eHtIPSwQ7XswrwzZ6LWG3g9N13T/XEsz8cN9BTZRVPcU367jAOJN6q1ObzrQlYsjMRozoFVjkaZqzfXE0J3FTGf3eKtTo8vPwgmvu4YOnjdxduTLmdj3s/0T+taIr1J65i5k8n8Y8ewfjskS7S9sSMHOw4l4Go
O4sjXvxodKUgU6zVIeV2PlobCHFxKVkI9XeD+53fifK/Y2O7BGHJxG5SH+tPXMX4bsH1MnLFoEJEFnUjR4Opq47i4e7BeLBrM5OGkYtKdNX+D3DBH2ew8sDlavuyUwA6AUwf0hbT7wut8lTW9awC9F20s9o+yXZsnNYPrXxdserAZfxz+3mz9FkWVL7fn4T3/zyrt29wez+M7BSoN19qwQPheO/3MwAAbxdHZOZXv/5P2WcIIfBzzFW88csp7HtjMDycHdFlwVap3ZxRYWgf4A5fN5XRq9xmDW+Hl+8rXXMp6q9z+GbvJTwV2QLbz6bjenYhvnysK77ecwnnUtVo7euKSzfzMKi9H3Yn3Ki2zgsfjoKjmWfO21xQWbZsGT799FOkpaWhS5cuWLJkCXr37l3t+xhUiBq2giItvtt3CZ9vu/vHx1VpjxKdQMIHo7D5dBrcnRzQt00Tg1fsGJJTWIyI+Vurb2hAn1Y+eG9sOEZ/tQ9A6b/87wvzx2dbE/Bgt2aYuvIYrmUV1Krvinq19Maxy6WnqQ7OuY/hqgF6pEcw3h/XCf/33xjsPV99YKjO9pkDoLS3x4BPd5mhurvGd2uGf07oatY+bSqo/O9//8NTTz2Fr7/+Gn369MHixYvx888/IyEhAf7+Vd+RlUGFqHHIyi9CbHIWBrTzM8v5+idXHMG+CzehUADzxnTEg92awafcKFCupgRrjlyBm8oRozoFYtXBy7BTKPDq0NJ/seZpSmCnUFRaL+XTLX9j2a6Lta5rw0t9EdGs9OorB3s76HQCmhIdnJX2eqcQ6uqHp3ujb5smKNYKzP/9DP533PpPTZG8zH3JvU0FlT59+qBXr15YunQpAECn0yEkJASvvPIK5syZU+V7GVSIyNp8GH0Wqw9egavKHk/c0wLqgmJpHsaCB8LxVGQLfLIlAQPb+SGimSd2/p2BwWH+Vc5/KC/mym2oHOzxYfQ5vD2mA45dvo2OTT3Qp3UTAMDpa9mIv5aNx3qF4HZeETycHasdth+7ZD/irxm/1PzPV/ojNbsQz/1wHD6uStzOKwJQOpm1ta+rNJqVri6Ek6M9FIrSFZa1OmHSfKDvnuqJJm5KdGvujdjkTIy/szjhgHZ+6BrsCV93FUZ1ago/dxVKtDoUFGurHRU7MOc+LN15AfsTb6JXCx+sj72mt3/d8/cgT1OCyDZNcOpqNvq08sH7f5416VRjY7Pm2T7oG+pr1j5tJqgUFRXBxcUFv/zyCx588EFp++TJk5GVlYXffvtNr71Go4FGo5Feq9VqhISEMKgQEZlJYbEW3+69hKf6tjS4yB9QevrM3cn4JcflnU/PQWJGLkaGB0qjQ+ZwNTMff8Wnol2AO3RCYO3RFAxq74e3N5zG1tcGoF1A7dZkSbmdjwOJN3EjR4PPt53Hf57pDRelPToHe+HKrTzczitG+0B3nEzJQhM3JdoHuKNEJ6DVCRxIvClNIq/I0BVOhigd7PBw92Zo4qrCjKFtcS2rABk5Gsz8KQ5rnr0H17MK4Oeu0psoe+TSLUz492GTjq+JqxK37gTNTs088OOz90jf8+28Iqgc7LD9XDpclQ6wswP6h/o17sm0169fR7NmzXDw4EFERkZK29944w3s2bMHR44c0Ws/f/58LFiwoFI/DCpERES2oyZBxaZWy5k7dy6ys7OlR0oKz6sSERE1ZLKuTOvr6wt7e3ukp+svUpSeno7AwMrDZyqVCiqVylLlERERkcxkHVFRKpXo0aMHduy4u9qgTqfDjh079E4FERERUeMk+71+Zs6cicmTJ6Nnz57o3bs3Fi9ejLy8PEydOlXu0oiIiEhmsgeVCRMm4MaNG3j33XeRlpaGrl27YvPmzQgIsNw9B4iIiMg6yb6OSl1wHRUiIiLb02Cv+iEiIqLGhUGFiIiIrBaDChEREVktBhUiIiKyWgwqREREZLUYVIiIiMhqMagQERGR1WJQISIiIqsl+8q0dVG2Vp1arZa5EiIiIjJV
2d9tU9actemgkpOTAwAICQmRuRIiIiKqqZycHHh6elbZxqaX0NfpdLh+/Trc3d2hUCjM2rdarUZISAhSUlIa5PL8PD7b19CPkcdn2xr68QEN/xjr8/iEEMjJyUFQUBDs7KqehWLTIyp2dnYIDg6u18/w8PBokD+AZXh8tq+hHyOPz7Y19OMDGv4x1tfxVTeSUoaTaYmIiMhqMagQERGR1WJQMUKlUuG9996DSqWSu5R6weOzfQ39GHl8tq2hHx/Q8I/RWo7PpifTEhERUcPGERUiIiKyWgwqREREZLUYVIiIiMhqMagQERGR1WJQMWDZsmVo2bIlnJyc0KdPHxw9elTukgyaP38+FAqF3iMsLEzaX1hYiGnTpqFJkyZwc3PDww8/jPT0dL0+kpOTMWbMGLi4uMDf3x+zZ89GSUmJXpvdu3eje/fuUKlUCA0NxapVq+rlePbu3YuxY8ciKCgICoUCGzdu1NsvhMC7776Lpk2bwtnZGUOHDsWFCxf02ty+fRuTJk2Ch4cHvLy88MwzzyA3N1evzalTp3DvvffCyckJISEh+OSTTyrV8vPPPyMsLAxOTk6IiIjAX3/9Ve/HN2XKlErf58iRI23m+KKiotCrVy+4u7vD398fDz74IBISEvTaWPJn0ty/x6Yc36BBgyp9hy+88IJNHB8ALF++HJ07d5YW+IqMjMSmTZuk/bb8/ZlyfLb+/VW0aNEiKBQKzJgxQ9pmk9+hID3r1q0TSqVSfP/99+LMmTPiueeeE15eXiI9PV3u0ip57733RHh4uEhNTZUeN27ckPa/8MILIiQkROzYsUMcP35c3HPPPaJv377S/pKSEtGpUycxdOhQERsbK/766y/h6+sr5s6dK7W5dOmScHFxETNnzhRnz54VS5YsEfb29mLz5s1mP56//vpLvP3222L9+vUCgNiwYYPe/kWLFglPT0+xceNGcfLkSfHAAw+IVq1aiYKCAqnNyJEjRZcuXcThw4fFvn37RGhoqJg4caK0Pzs7WwQEBIhJkyaJ06dPi7Vr1wpnZ2fxzTffSG0OHDgg7O3txSeffCLOnj0r3nnnHeHo6Cji4+Pr9fgmT54sRo4cqfd93r59W6+NNR/fiBEjxMqVK8Xp06dFXFycGD16tGjevLnIzc2V2ljqZ7I+fo9NOb6BAweK5557Tu87zM7OtonjE0KI33//XURHR4vz58+LhIQE8dZbbwlHR0dx+vRpIYRtf3+mHJ+tf3/lHT16VLRs2VJ07txZvPrqq9J2W/wOGVQq6N27t5g2bZr0WqvViqCgIBEVFSVjVYa99957okuXLgb3ZWVlCUdHR/Hzzz9L286dOycAiEOHDgkhSv9w2tnZibS0NKnN8uXLhYeHh9BoNEIIId544w0RHh6u1/eECRPEiBEjzHw0+ir+IdfpdCIwMFB8+umn0rasrCyhUqnE2rVrhRBCnD17VgAQx44dk9ps2rRJKBQKce3aNSGEEP/617+Et7e3dHxCCPHmm2+K9u3bS68fffRRMWbMGL16+vTpI/7v//6v3o5PiNKgMm7cOKPvsaXjE0KIjIwMAUDs2bNHCGHZn0lL/B5XPD4hSv/Qlf+jUJEtHV8Zb29v8d133zW476/i8QnRcL6/nJwc0bZtW7Ft2za9Y7LV75CnfsopKipCTEwMhg4dKm2zs7PD0KFDcejQIRkrM+7ChQsICgpC69atMWnSJCQnJwMAYmJiUFxcrHcsYWFhaN68uXQshw4dQkREBAICAqQ2I0aMgFqtxpkzZ6Q25fsoa2Pp/x5JSUlIS0vTq8XT0xN9+vTROx4vLy/07NlTajN06FDY2dnhyJEjUpsBAwZAqVRKbUaMGIGEhARkZmZKbeQ65t27d8Pf3x/t27fHiy++iFu3bkn7bO34srOzAQA+Pj4ALPczaanf44rHV+bHH3+Er68vOnXqhLlz5yI/P1/aZ0vHp9VqsW7dOuTl5SEyMrLBfX8Vj69MQ/j+pk2bhjFjxlSq
w1a/Q5u+KaG53bx5E1qtVu8LAoCAgAD8/fffMlVlXJ8+fbBq1Sq0b98eqampWLBgAe69916cPn0aaWlpUCqV8PLy0ntPQEAA0tLSAABpaWkGj7VsX1Vt1Go1CgoK4OzsXE9Hp6+sHkO1lK/V399fb7+DgwN8fHz02rRq1apSH2X7vL29jR5zWR/1ZeTIkXjooYfQqlUrXLx4EW+99RZGjRqFQ4cOwd7e3qaOT6fTYcaMGejXrx86deokfb4lfiYzMzPr/ffY0PEBwOOPP44WLVogKCgIp06dwptvvomEhASsX7/eZo4vPj4ekZGRKCwshJubGzZs2ICOHTsiLi6uQXx/xo4PaBjf37p163DixAkcO3as0j5b/R1kULFho0aNkp537twZffr0QYsWLfDTTz9ZLECQ+Tz22GPS84iICHTu3Blt2rTB7t27MWTIEBkrq7lp06bh9OnT2L9/v9yl1Atjx/f8889LzyMiItC0aVMMGTIEFy9eRJs2bSxdZq20b98ecXFxyM7Oxi+//ILJkydjz549cpdlNsaOr2PHjjb//aWkpODVV1/Ftm3b4OTkJHc5ZsNTP+X4+vrC3t6+0gzo9PR0BAYGylSV6by8vNCuXTskJiYiMDAQRUVFyMrK0mtT/lgCAwMNHmvZvqraeHh4WDQMldVT1XcTGBiIjIwMvf0lJSW4ffu2WY7Z0j8DrVu3hq+vLxITE6W6bOH4Xn75Zfz555/YtWsXgoODpe2W+pms799jY8dnSJ8+fQBA7zu09uNTKpUIDQ1Fjx49EBUVhS5duuDLL79sMN+fseMzxNa+v5iYGGRkZKB79+5wcHCAg4MD9uzZg6+++goODg4ICAiwye+QQaUcpVKJHj16YMeOHdI2nU6HHTt26J3DtFa5ubm4ePEimjZtih49esDR0VHvWBISEpCcnCwdS2RkJOLj4/X++G3btg0eHh7SUGhkZKReH2VtLP3fo1WrVggMDNSrRa1W48iRI3rHk5WVhZiYGKnNzp07odPppP/hREZGYu/evSguLpbabNu2De3bt4e3t7fUxhqO+erVq7h16xaaNm0q1WXNxyeEwMsvv4wNGzZg586dlU5BWepnsr5+j6s7PkPi4uIAQO87tNbjM0an00Gj0dj891fd8Rlia9/fkCFDEB8fj7i4OOnRs2dPTJo0SXpuk99hjaffNnDr1q0TKpVKrFq1Spw9e1Y8//zzwsvLS28GtLV4/fXXxe7du0VSUpI4cOCAGDp0qPD19RUZGRlCiNLL0Jo3by527twpjh8/LiIjI0VkZKT0/rLL0IYPHy7i4uLE5s2bhZ+fn8HL0GbPni3OnTsnli1bVm+XJ+fk5IjY2FgRGxsrAIgvvvhCxMbGiitXrgghSi9P9vLyEr/99ps4deqUGDdunMHLk7t16yaOHDki9u/fL9q2bat3+W5WVpYICAgQTz75pDh9+rRYt26dcHFxqXT5roODg/jss8/EuXPnxHvvvWeWy3erOr6cnBwxa9YscejQIZGUlCS2b98uunfvLtq2bSsKCwtt4vhefPFF4enpKXbv3q13eWd+fr7UxlI/k/Xxe1zd8SUmJor3339fHD9+XCQlJYnffvtNtG7dWgwYMMAmjk8IIebMmSP27NkjkpKSxKlTp8ScOXOEQqEQW7duFULY9vdX3fE1hO/PkIpXMtnid8igYsCSJUtE8+bNhVKpFL179xaHDx+WuySDJkyYIJo2bSqUSqVo1qyZmDBhgkhMTJT2FxQUiJdeekl4e3sLFxcXMX78eJGamqrXx+XLl8WoUaOEs7Oz8PX1Fa+//rooLi7Wa7Nr1y7RtWtXoVQqRevWrcXKlSvr5Xh27dolAFR6TJ48WQhReonyvHnzREBAgFCpVGLIkCEiISFBr49bt26JiRMnCjc3N+Hh4SGmTp0qcnJy9NqcPHlS9O/fX6hUKtGsWTOxaNGiSrX89NNPol27dkKpVIrw8HARHR1dr8eXn58vhg8fLvz8
/ISjo6No0aKFeO655yr9Ulvz8Rk6NgB6Py+W/Jk09+9xdceXnJwsBgwYIHx8fIRKpRKhoaFi9uzZeutwWPPxCSHE008/LVq0aCGUSqXw8/MTQ4YMkUKKELb9/VV3fA3h+zOkYlCxxe9QIYQQNR+HISIiIqp/nKNCREREVotBhYiIiKwWgwoRERFZLQYVIiIisloMKkRERGS1GFSIiIjIajGoEBERkdViUCEiIiKrxaBCRHXSsmVLLF682OT2u3fvhkKhqHRjNCIiQ7gyLVEjM2jQIHTt2rVG4aIqN27cgKurK1xcXExqX1RUhNu3byMgIAAKhcIsNdTU7t27MXjwYGRmZsLLy0uWGojINA5yF0BE1kcIAa1WCweH6v8X4efnV6O+lUplnW9nT0SNB0/9EDUiU6ZMwZ49e/Dll19CoVBAoVDg8uXL0umYTZs2oUePHlCpVNi/fz8uXryIcePGISAgAG5ubujVqxe2b9+u12fFUz8KhQLfffcdxo8fDxcXF7Rt2xa///67tL/iqZ9Vq1bBy8sLW7ZsQYcOHeDm5oaRI0ciNTVVek9JSQmmT58OLy8vNGnSBG+++SYmT56MBx980OixXrlyBWPHjoW3tzdcXV0RHh6Ov/76C5cvX8bgwYMBAN7e3lAoFJgyZQqA0lvRR0VFoVWrVnB2dkaXLl3wyy+/VKo9OjoanTt3hpOTE+655x6cPn26lt8IEVWHQYWoEfnyyy8RGRmJ5557DqmpqUhNTUVISIi0f86cOVi0aBHOnTuHzp07Izc3F6NHj8aOHTsQGxuLkSNHYuzYsUhOTq7ycxYsWIBHH30Up06dwujRozFp0iTcvn3baPv8/Hx89tln+M9//oO9e/ciOTkZs2bNkvZ//PHH+PHHH7Fy5UocOHAAarUaGzdurLKGadOmQaPRYO/evYiPj8fHH38MNzc3hISE4NdffwUAJCQkIDU1FV9++SUAICoqCj/88AO+/vprnDlzBq+99hqeeOIJ7NmzR6/v2bNn4/PPP8exY8fg5+eHsWPHori4uMp6iKiWanXPZSKyWRVv+y5E6S3bAYiNGzdW+/7w8HCxZMkS6XWLFi3EP//5T+k1APHOO+9Ir3NzcwUAsWnTJr3PyszMFEIIsXLlSgFAJCYmSu9ZtmyZCAgIkF4HBASITz/9VHpdUlIimjdvLsaNG2e0zoiICDF//nyD+yrWIIQQhYWFwsXFRRw8eFCv7TPPPCMmTpyo975169ZJ+2/duiWcnZ3F//73P6O1EFHtcY4KEUl69uyp9zo3Nxfz589HdHQ0UlNTUVJSgoKCgmpHVDp37iw9d3V1hYeHBzIyMoy2d3FxQZs2baTXTZs2ldpnZ2cjPT0dvXv3lvbb29ujR48e0Ol0RvucPn06XnzxRWzduhVDhw7Fww8/rFdXRYmJicjPz8ewYcP0thcVFaFbt2562yIjI6XnPj4+aN++Pc6dO2e0byKqPQYVIpK4urrqvZ41axa2bduGzz77DKGhoXB2dsY//vEPFBUVVdmPo6Oj3muFQlFlqDDUXtTxgsRnn30WI0aMQHR0NLZu3YqoqCh8/vnneOWVVwy2z83NBQBER0ejWbNmevtUKlWdaiGi2uMcFaJGRqlUQqvVmtT2wIEDmDJlCsaPH4+IiAgEBgbi8uXL9VtgBZ6enggICMCxY8ekbVqtFidOnKj2vSEhIXjhhRewfv16vP766/j2228BlP43KOunTMeOHaFSqZCcnIzQ0FC9R/l5PABw+PBh6XlmZibOnz+PDh061Ok4icgwjqgQNTItW7bEkSNHcPnyZbi5ucHHx8do27Zt22L9+vUYO3YsFAoF5s2bV+XISH155ZVXEBUVhdDQUISFhWHJkiXIzMysch2WGTNmYNSoUWjXrh0yMzOxa9cuKUy0aNECCoUCf/75J0aPHg1nZ2e4u7tj1qxZeO2116DT6dC/f39kZ2fjwIED8PDwwOTJk6W+33//fTRp0gQBAQF4++23
4evrW+UVSERUexxRIWpkZs2aBXt7e3Ts2BF+fn5Vzjf54osv4O3tjb59+2Ls2LEYMWIEunfvbsFqS7355puYOHEinnrqKURGRsLNzQ0jRoyAk5OT0fdotVpMmzYNHTp0wMiRI9GuXTv861//AgA0a9YMCxYswJw5cxAQEICXX34ZALBw4ULMmzcPUVFR0vuio6PRqlUrvb4XLVqEV199FT169EBaWhr++OMPaZSGiMyLK9MSkc3R6XTo0KEDHn30USxcuNBin8sVbYksj6d+iMjqXblyBVu3bsXAgQOh0WiwdOlSJCUl4fHHH5e7NCKqZzz1Q0RWz87ODqtWrUKvXr3Qr18/xMfHY/v27ZzAStQI8NQPERERWS2OqBAREZHVYlAhIiIiq8WgQkRERFaLQYWIiIisFoMKERERWS0GFSIiIrJaDCpERERktRhUiIiIyGr9P192IIz/3yOtAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.optim import Adam\n",
    "import numpy as np\n",
    "\n",
    "# Train a sequence-to-sequence machine-translation model\n",
    "def train_seq2seq_mt(train_data, encoder, decoder, epochs=20,\\\n",
    "        learning_rate=1e-3):\n",
    "    \"\"\"Train the encoder/decoder with teacher forcing and plot the loss.\n",
    "\n",
    "    train_data: list of (input_ids, target_ids) pairs\n",
    "    encoder, decoder: the RNN encoder and (attention) decoder modules\n",
    "    epochs: number of passes over train_data\n",
    "    learning_rate: Adam learning rate shared by both optimizers\n",
    "    \"\"\"\n",
    "    # One optimizer per module; NLLLoss assumes the decoder emits\n",
    "    # log-probabilities (e.g. via log_softmax)\n",
    "    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "\n",
    "    step_losses = []\n",
    "    plot_losses = []\n",
    "    # Bug fix: iterate over the `epochs` parameter; the original\n",
    "    # read the global `n_epochs`, silently ignoring the argument\n",
    "    with trange(epochs, desc='epoch', ncols=60) as pbar:\n",
    "        for epoch in pbar:\n",
    "            np.random.shuffle(train_data)\n",
    "            for step, data in enumerate(train_data):\n",
    "                # Convert source/target sequences to 1 * seq_len tensors.\n",
    "                # For simplicity the batch size is 1; with larger batches\n",
    "                # the encoder would need padding and would have to return\n",
    "                # the hidden state of the last non-padding token, and\n",
    "                # decoding would need matching adjustments\n",
    "                input_ids, target_ids = data\n",
    "                input_tensor, target_tensor = \\\n",
    "                    torch.tensor(input_ids).unsqueeze(0),\\\n",
    "                    torch.tensor(target_ids).unsqueeze(0)\n",
    "\n",
    "                encoder_optimizer.zero_grad()\n",
    "                decoder_optimizer.zero_grad()\n",
    "\n",
    "                encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "                # Feed the ground-truth target sequence (teacher forcing)\n",
    "                decoder_outputs, _, _ = decoder(encoder_outputs,\\\n",
    "                    encoder_hidden, target_tensor)\n",
    "\n",
    "                loss = criterion(\n",
    "                    decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
    "                    target_tensor.view(-1)\n",
    "                )\n",
    "                pbar.set_description(f'epoch-{epoch}, '+\\\n",
    "                    f'loss={loss.item():.4f}')\n",
    "                step_losses.append(loss.item())\n",
    "                # With batch size 1 the per-step loss is very noisy;\n",
    "                # averaging the last 32 steps gives a smoother curve\n",
    "                plot_losses.append(np.mean(step_losses[-32:]))\n",
    "                loss.backward()\n",
    "\n",
    "                encoder_optimizer.step()\n",
    "                decoder_optimizer.step()\n",
    "\n",
    "    plot_losses = np.array(plot_losses)\n",
    "    plt.plot(range(len(plot_losses)), plot_losses)\n",
    "    plt.xlabel('training step')\n",
    "    plt.ylabel('loss')\n",
    "    plt.show()\n",
    "\n",
    "    \n",
    "# Hyperparameters for the demo run\n",
    "hidden_size = 128\n",
    "n_epochs = 20\n",
    "learning_rate = 1e-3\n",
    "\n",
    "# input_lang / output_lang and AttnRNNDecoder are defined earlier\n",
    "# in the notebook (not shown in this excerpt)\n",
    "encoder = RNNEncoder(input_lang.n_words, hidden_size)\n",
    "decoder = AttnRNNDecoder(output_lang.n_words, hidden_size)\n",
    "\n",
    "train_seq2seq_mt(train_data, encoder, decoder, n_epochs, learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be47e115",
   "metadata": {},
   "source": [
    "下面实现贪心搜索解码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "678192b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 人 工 智 能 计 算 思 维 启 蒙 教 程\n",
      "target： a .t .a . calculating thoughts enlightenment\n",
      "pred： a .t .a . calculating thoughts enlightenment , thinking about a girl .\n",
      "\n",
      "input： 从 零 开 始 ： p h o t o s h o p 工 具 详 解 与 实 战\n",
      "target： from scratch: photoshop tool detailed and operational\n",
      "pred： spring cloud microservice distributed architecture techniques and combat\n",
      "\n",
      "input： 大 学 生 职 业 生 涯 规 划 与 就 业 创 业 指 导 （ 微 课 版 ）\n",
      "target： career planning and entrepreneurship guidance for university students (micro-curricular version)\n",
      "pred： career planning and entrepreneurship guidance for university students (micro-curricular version)\n",
      "\n",
      "input： 职 业 发 展 与 就 业 创 业 指 导 （ 慕 课 版 ）\n",
      "target： career development and entrepreneurship for employment (curriculum version)\n",
      "pred： career development and entrepreneurship for employment (curriculum version)\n",
      "\n",
      "input： m i c r o s o f t a z u r e 机 器 学 习 和 预 测 分 析\n",
      "target： microsoft azure machine learning and forecasting analysis\n",
      "pred： microsoft and fluent learning efficiently analysis , operational combat and combat techniques\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "def greedy_decode(encoder, decoder, sentence, input_lang, output_lang):\n",
    "    # Greedy decoding: keep only the single most probable token at\n",
    "    # each step. Returns (decoded sentence, attention weights).\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sentence to a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "        \n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        # Called without a target sequence (unlike training, which\n",
    "        # passes target_tensor), so the decoder feeds back its own\n",
    "        # predictions step by step\n",
    "        decoder_outputs, decoder_hidden, decoder_attn = \\\n",
    "            decoder(encoder_outputs, encoder_hidden)\n",
    "        \n",
    "        # Take the most probable token at every step\n",
    "        _, topi = decoder_outputs.topk(1)\n",
    "        \n",
    "        # Collect token ids, stopping at the first <eos>\n",
    "        decoded_ids = []\n",
    "        for idx in topi.squeeze():\n",
    "            if idx.item() == EOS_token:\n",
    "                break\n",
    "            decoded_ids.append(idx.item())\n",
    "    return output_lang.ids2sent(decoded_ids), decoder_attn\n",
    "            \n",
    "# Show greedy-decoded predictions for a few random pairs\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence, _ = greedy_decode(encoder, decoder, pair[0],\n",
    "        input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e38320f0",
   "metadata": {},
   "source": [
    "\n",
    "接下来使用束搜索解码来验证模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3496efa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 超 简 单 的 简 谱 图 解 教 程\n",
      "target： a very simple pamphlet pedagogic pedagogy .\n",
      "pred： a very simple pamphlet pedagogic pedagogy .\n",
      "\n",
      "input： 想 象 力 构 图 与 创 作 思 维\n",
      "target： imagination , mapping and creative thinking .\n",
      "pred： imagination , mapping and creative thinking .\n",
      "\n",
      "input： 税 务 会 计 理 论 与 实 务 （ 第 2 版 ）\n",
      "target： tax accounting theory and practice (version 2)\n",
      "pred： tax accounting theory and practice (version 2)\n",
      "\n",
      "input： 线 描 古 风 绘 古 风 仙 侠 人 物 绘 制 一 学 就 会\n",
      "target： it's an ancient painting of the art of the ancient cage man , as soon as he learns it .\n",
      "pred： it's an ancient painting of the art of the ancient cage man , as soon as he learns it .\n",
      "\n",
      "input： 客 户 服 务 — — 策 略 、 技 术 、 管 理\n",
      "target： client services — strategy , technology , management\n",
      "pred： client services — strategy , technology , technology , innovation , global technology thinking , methods , techniques , technology , innovation , global technology thinking , technology and innovation\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Container class that manages the current set of beam candidates\n",
    "class BeamHypotheses:\n",
    "    def __init__(self, num_beams, max_length):\n",
    "        # num_beams: maximum number of candidates kept at once\n",
    "        # max_length: maximum decoding length of any candidate\n",
    "        self.max_length = max_length\n",
    "        self.num_beams = num_beams\n",
    "        # Each beam is a (score, token_id_list, decoder_hidden) tuple\n",
    "        self.beams = []\n",
    "        # Sentinel high value so the first candidates are always accepted\n",
    "        self.worst_score = 1e9\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.beams)\n",
    "    \n",
    "    # Add one candidate and update the worst score in the container\n",
    "    def add(self, sum_logprobs, hyp, hidden):\n",
    "        # Length-normalize so hypotheses of different lengths compare\n",
    "        self.worst_score = 1e9 if False else self.worst_score  # no-op placeholder removed below\n",
    "        score = sum_logprobs / max(len(hyp), 1)\n",
    "        if len(self) < self.num_beams or score > self.worst_score:\n",
    "            # Accept when capacity is not full or the score beats the worst\n",
    "            self.beams.append((score, hyp, hidden))\n",
    "            if len(self) > self.num_beams:\n",
    "                # Over capacity: delete the single worst candidate.\n",
    "                # sorted_scores is computed before deletion, so index 1\n",
    "                # is the new worst after index 0 is removed\n",
    "                sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                    (s, _, _) in enumerate(self.beams)])\n",
    "                del self.beams[sorted_scores[0][1]]\n",
    "                self.worst_score = sorted_scores[1][0]\n",
    "            else:\n",
    "                self.worst_score = min(score, self.worst_score)\n",
    "    \n",
    "    # Pop one unfinished candidate. The first return value says whether\n",
    "    # one was found; if so, the second is its (score, hyp, hidden) tuple\n",
    "    def pop(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        for i, (s, hyp, hid) in enumerate(self.beams):\n",
    "            # Unfinished means: shorter than the maximum decoding\n",
    "            # length and not already ended with <eos>\n",
    "            if len(hyp) < self.max_length and (len(hyp) == 0\\\n",
    "                    or hyp[-1] != EOS_token):\n",
    "                del self.beams[i]\n",
    "                if len(self) > 0:\n",
    "                    # Recompute the worst score after the removal\n",
    "                    sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                        (s, _, _) in enumerate(self.beams)])\n",
    "                    self.worst_score = sorted_scores[0][0]\n",
    "                else:\n",
    "                    self.worst_score = 1e9\n",
    "                return True, (s, hyp, hid)\n",
    "        return False, None\n",
    "    \n",
    "    # Pop the highest-scoring candidate. The first return value says\n",
    "    # whether one exists; if so, the second is the candidate tuple\n",
    "    def pop_best(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        sorted_scores = sorted([(s, idx) for idx, (s, _, _)\\\n",
    "            in enumerate(self.beams)])\n",
    "        return True, self.beams[sorted_scores[-1][1]]\n",
    "\n",
    "\n",
    "def beam_search_decode(encoder, decoder, sentence, input_lang,\n",
    "        output_lang, num_beams=3):\n",
    "    # Beam-search decoding: keep up to num_beams partial hypotheses\n",
    "    # and repeatedly expand the next unfinished one. Returns the\n",
    "    # decoded target sentence, or '' if nothing was produced.\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sentence to a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        # Seed the container with one empty hypothesis carrying the\n",
    "        # encoder's final hidden state\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        init_hyp = []\n",
    "        hypotheses = BeamHypotheses(num_beams, MAX_LENGTH)\n",
    "        hypotheses.add(0, init_hyp, encoder_hidden)\n",
    "\n",
    "        while True:\n",
    "            # Take out one unfinished candidate per iteration;\n",
    "            # stop once every remaining candidate is finished\n",
    "            flag, item = hypotheses.pop()\n",
    "            if not flag:\n",
    "                break\n",
    "                \n",
    "            score, hyp, decoder_hidden = item\n",
    "            \n",
    "            # Current decoder input: the hypothesis' last token,\n",
    "            # or <sos> when the hypothesis is still empty\n",
    "            if len(hyp) > 0:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(hyp[-1])\n",
    "            else:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(SOS_token)\n",
    "\n",
    "            # Run a single decoding step\n",
    "            decoder_output, decoder_hidden, _ = decoder.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "\n",
    "            # Take the top-k tokens from the output distribution\n",
    "            topk_values, topk_ids = decoder_output.topk(num_beams)\n",
    "            # Build the expanded candidates and add them back\n",
    "            for logp, token_id in zip(topk_values.squeeze(),\\\n",
    "                    topk_ids.squeeze()):\n",
    "                # score was length-normalized by BeamHypotheses.add,\n",
    "                # so multiply by len(hyp) to recover the running sum\n",
    "                # of log-probabilities before adding the new token\n",
    "                sum_logprobs = score * len(hyp) + logp.item()\n",
    "                new_hyp = hyp + [token_id.item()]\n",
    "                hypotheses.add(sum_logprobs, new_hyp, decoder_hidden)\n",
    "\n",
    "        # Every candidate is finished: return the best one,\n",
    "        # stripping a trailing <eos> if present\n",
    "        flag, item = hypotheses.pop_best()\n",
    "        if flag:\n",
    "            hyp = item[1]\n",
    "            if hyp[-1] == EOS_token:\n",
    "                del hyp[-1]\n",
    "            return output_lang.ids2sent(hyp)\n",
    "        else:\n",
    "            return ''\n",
    "\n",
    "# Show beam-search-decoded predictions for a few random pairs\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence = beam_search_decode(encoder, decoder,\\\n",
    "        pair[0], input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12fc379e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
